{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "04c5d23d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizerFast(name_or_path='hfl/rbt3', vocab_size=21128, model_max_len=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/加载编码器\n",
    "from transformers import AutoTokenizer\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained('hfl/rbt3')\n",
    "\n",
    "tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "1a374f76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 海 钓 比 赛 地 点 在 厦 门 与 金 门 之 间 的 海 域 。 [SEP]\n",
      "[CLS] 这 座 依 山 傍 水 的 博 物 馆 由 国 内 一 流 的 设 计 [SEP]\n",
      "input_ids tensor([[ 101, 3862, 7157, 3683, 6612, 1765, 4157, 1762, 1336, 7305,  680, 7032,\n",
      "         7305,  722, 7313, 4638, 3862, 1818,  511,  102],\n",
      "        [ 101, 6821, 2429,  898, 2255,  988, 3717, 4638, 1300, 4289, 7667, 4507,\n",
      "         1744, 1079,  671, 3837, 4638, 6392, 6369,  102]])\n",
      "token_type_ids tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],\n",
      "        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]])\n",
      "attention_mask tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],\n",
      "        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]])\n"
     ]
    }
   ],
   "source": [
    "#第10章/编码测试\n",
    "out = tokenizer.batch_encode_plus(\n",
    "    [[\n",
    "        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',\n",
    "        '的', '海', '域', '。'\n",
    "    ],\n",
    "     [\n",
    "         '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',\n",
    "         '流', '的', '设', '计', '师', '主', '持', '设', '计', '。'\n",
    "     ]],\n",
    "    truncation=True,\n",
    "    padding=True,\n",
    "    return_tensors='pt',\n",
    "    max_length=20,\n",
    "    is_split_into_words=True)\n",
    "\n",
    "#还原编码为句子\n",
    "print(tokenizer.decode(out['input_ids'][0]))\n",
    "print(tokenizer.decode(out['input_ids'][1]))\n",
    "\n",
    "for k, v in out.items():\n",
    "    print(k, v)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "a356986c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间', '的', '海', '域', '。']\n",
      "[0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 5, 6, 0, 0, 0, 0, 0, 0]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "20865"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/定义数据集\n",
    "import torch\n",
    "from datasets import load_dataset, load_from_disk\n",
    "\n",
    "\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "\n",
    "        #在线加载数据集\n",
    "        #dataset = load_dataset(path='peoples_daily_ner', split=split)\n",
    "\n",
    "        #离线加载数据集\n",
    "        dataset = load_from_disk(\n",
    "            dataset_path='./data/peoples_daily_ner')[split]\n",
    "\n",
    "        self.dataset = dataset\n",
    "\n",
    "        #dataset.features['ner_tags'].feature.num_classes\n",
    "        #7\n",
    "\n",
    "        #dataset.features['ner_tags'].feature.names\n",
    "        #['O', 'B-PER', 'I-PER', 'B-ORG', 'I-ORG', 'B-LOC', 'I-LOC']\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        tokens = self.dataset[i]['tokens']\n",
    "        labels = self.dataset[i]['ner_tags']\n",
    "\n",
    "        return tokens, labels\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "\n",
    "tokens, labels = dataset[0]\n",
    "print(tokens), print(labels)\n",
    "\n",
    "len(dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "eac759df",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "# for i in range(100, 200):\n",
    "#     text = dataset[i][0][:15]\n",
    "#     label = dataset[i][1][:15]\n",
    "\n",
    "#     if sum(label) == 0:\n",
    "#         continue\n",
    "\n",
    "#     text = ' '.join(text)\n",
    "#     label = ' '.join([str(j) for j in label])\n",
    "\n",
    "#     print(text)\n",
    "#     print(label)\n",
    "#     print()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "0579bb63",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'cuda'"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/定义计算设备\n",
    "device = 'cpu'\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "\n",
    "device"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "ca81efee",
   "metadata": {},
   "outputs": [],
   "source": [
    "#第10章/定义数据整理函数\n",
    "def collate_fn(data):\n",
    "    tokens = [i[0] for i in data]\n",
    "    labels = [i[1] for i in data]\n",
    "\n",
    "    #编码\n",
    "    inputs = tokenizer.batch_encode_plus(tokens,\n",
    "                                         truncation=True,\n",
    "                                         padding=True,\n",
    "                                         return_tensors='pt',\n",
    "                                         max_length=512,\n",
    "                                         is_split_into_words=True)\n",
    "\n",
    "    #求一批数据中最长的句子长度\n",
    "    lens = inputs['input_ids'].shape[1]\n",
    "\n",
    "    #在labels的头尾补充7，把所有的labels补充成统一的长度\n",
    "    for i in range(len(labels)):\n",
    "        labels[i] = [7] + labels[i]\n",
    "        labels[i] += [7] * lens\n",
    "        labels[i] = labels[i][:lens]\n",
    "\n",
    "    #把编码结果移动到计算设备\n",
    "    for k, v in inputs.items():\n",
    "        inputs[k] = v.to(device)\n",
    "\n",
    "    #把统一长度的labels组装成矩阵，并移动到计算设备\n",
    "    labels = torch.LongTensor(labels).to(device)\n",
    "\n",
    "    return inputs, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ef2f8d9e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids torch.Size([2, 37])\n",
      "token_type_ids torch.Size([2, 37])\n",
      "attention_mask torch.Size([2, 37])\n",
      "labels torch.Size([2, 37])\n"
     ]
    }
   ],
   "source": [
    "#第10章/数据整理函数试算\n",
    "#模拟一批数据\n",
    "data = [\n",
    "    ([\n",
    "        '海', '钓', '比', '赛', '地', '点', '在', '厦', '门', '与', '金', '门', '之', '间',\n",
    "        '的', '海', '域', '。'\n",
    "    ], [0, 0, 0, 0, 0, 0, 0, 5, 6, 0, 5, 6, 0, 0, 0, 0, 0, 0]),\n",
    "    ([\n",
    "        '这', '座', '依', '山', '傍', '水', '的', '博', '物', '馆', '由', '国', '内', '一',\n",
    "        '流', '的', '设', '计', '师', '主', '持', '设', '计', '，', '整', '个', '建', '筑',\n",
    "        '群', '精', '美', '而', '恢', '宏', '。'\n",
    "    ], [\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
    "        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0\n",
    "    ]),\n",
    "]\n",
    "\n",
    "#试算\n",
    "inputs, labels = collate_fn(data)\n",
    "\n",
    "for k, v in inputs.items():\n",
    "    print(k, v.shape)\n",
    "\n",
    "print('labels', labels.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "696bea03",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1304"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "len(loader)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5cc26170",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS] 不 过 ， 从 此 以 后 ， 惠 普 公 司 便 成 了 他 们 的 忠 诚 客 户 。 [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]\n",
      "tensor([7, 0, 0, 0, 0, 0, 0, 0, 0, 3, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,\n",
      "        7, 7, 7], device='cuda:0')\n",
      "input_ids torch.Size([16, 171])\n",
      "token_type_ids torch.Size([16, 171])\n",
      "attention_mask torch.Size([16, 171])\n"
     ]
    }
   ],
   "source": [
    "#第10章/查看数据样例\n",
    "for i, (inputs, labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(tokenizer.decode(inputs['input_ids'][0]))\n",
    "print(labels[0])\n",
    "\n",
    "for k, v in inputs.items():\n",
    "    print(k, v.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "421a37b0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at hfl/rbt3 were not used when initializing BertModel: ['cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3847.68\n"
     ]
    }
   ],
   "source": [
    "#第10章/加载预训练模型\n",
    "from transformers import AutoModel\n",
    "\n",
    "pretrained = AutoModel.from_pretrained('hfl/rbt3')\n",
    "\n",
    "pretrained.to(device)\n",
    "\n",
    "#统计参数量\n",
    "print(sum(i.numel() for i in pretrained.parameters()) / 10000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "id": "4932d723",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 171, 768])"
      ]
     },
     "execution_count": 11,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/模型试算\n",
    "#[b, lens] -> [b, lens, 768]\n",
    "pretrained(**inputs).last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "id": "b530230e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 171, 8])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/定义下游模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        #标识当前模型是否处于tuning模式\n",
    "        self.tuning = False\n",
    "        #当处于tuning模式时backbone应该属于当前模型的一部分，否则该变量为空\n",
    "        self.pretrained = None\n",
    "\n",
    "        #当前模型的神经网络层\n",
    "        self.rnn = torch.nn.GRU(input_size=768, hidden_size=768, batch_first=True)\n",
    "        self.fc = torch.nn.Linear(in_features=768, out_features=8)\n",
    "\n",
    "    def forward(self, inputs):\n",
    "        #根据当前模型是否处于tuning模式而使用外部backbone或内部backbone计算\n",
    "        if self.tuning:\n",
    "            out = self.pretrained(**inputs).last_hidden_state\n",
    "        else:\n",
    "            with torch.no_grad():\n",
    "                out = pretrained(**inputs).last_hidden_state\n",
    "\n",
    "        #backbone抽取的特征输入rnn网络进一步抽取特征\n",
    "        out, _ = self.rnn(out)\n",
    "\n",
    "        #rnn网络抽取的特征最后输入fc神经网络分类\n",
    "        out = self.fc(out).softmax(dim=2)\n",
    "\n",
    "        return out\n",
    "\n",
    "    #切换下游任务模型的的tuning模式\n",
    "    def fine_tuning(self, tuning):\n",
    "        self.tuning = tuning\n",
    "        #tuning模式时，训练backbone的参数\n",
    "        if tuning:\n",
    "            for i in pretrained.parameters():\n",
    "                i.requires_grad = True\n",
    "\n",
    "            pretrained.train()\n",
    "            self.pretrained = pretrained\n",
    "        #非tuning模式时，不训练backbone的参数\n",
    "        else:\n",
    "            for i in pretrained.parameters():\n",
    "                i.requires_grad_(False)\n",
    "\n",
    "            pretrained.eval()\n",
    "            self.pretrained = None\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model.to(device)\n",
    "\n",
    "model(inputs).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "3a2c01e4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([[-0.0526,  0.1013, -0.0173, -0.7658,  1.5053,  0.7575,  0.0168, -0.2533],\n",
       "         [-0.9728,  1.1428, -0.8979, -1.1862, -0.6049,  0.4607,  0.2360,  1.0206],\n",
       "         [-0.3074, -0.1413,  0.7964,  1.0810,  1.0588,  2.8770,  0.5829, -0.0668],\n",
       "         [-0.4274, -0.0931, -1.0921, -0.7383, -0.0292, -0.6288, -3.0649, -0.3452],\n",
       "         [-1.5576, -0.7772, -1.2155, -0.6495, -0.1605,  2.0787, -0.1997, -0.1986],\n",
       "         [ 0.1147,  1.3831, -0.1156,  0.8515, -0.3147,  0.7072, -0.4293,  0.9322]]),\n",
       " tensor([1., 1., 1., 1., 1., 1.]))"
      ]
     },
     "execution_count": 13,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/对计算结果和label变形,并且移除pad\n",
    "def reshape_and_remove_pad(outs, labels, attention_mask):\n",
    "    #变形,便于计算loss\n",
    "    #[b, lens, 8] -> [b*lens, 8]\n",
    "    outs = outs.reshape(-1, 8)\n",
    "    #[b, lens] -> [b*lens]\n",
    "    labels = labels.reshape(-1)\n",
    "\n",
    "    #忽略对pad的计算结果\n",
    "    #[b, lens] -> [b*lens - pad]\n",
    "    select = attention_mask.reshape(-1) == 1\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "\n",
    "    return outs, labels\n",
    "\n",
    "\n",
    "reshape_and_remove_pad(torch.randn(2, 3, 8), torch.ones(2, 3),\n",
    "                       torch.ones(2, 3))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "00ebbb80",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(1, 16, 1, 16)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#第10章/获取正确数量和总数\n",
    "def get_correct_and_total_count(labels, outs):\n",
    "    #[b*lens, 8] -> [b*lens]\n",
    "    outs = outs.argmax(dim=1)\n",
    "    correct = (outs == labels).sum().item()\n",
    "    total = len(labels)\n",
    "\n",
    "    #计算除了0以外元素的正确率,因为0太多了,包括的话,正确率很容易虚高\n",
    "    select = labels != 0\n",
    "    outs = outs[select]\n",
    "    labels = labels[select]\n",
    "    correct_content = (outs == labels).sum().item()\n",
    "    total_content = len(labels)\n",
    "\n",
    "    return correct, total, correct_content, total_content\n",
    "\n",
    "\n",
    "get_correct_and_total_count(torch.ones(16), torch.randn(16, 8))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "b6c80630",
   "metadata": {
    "scrolled": false
   },
   "outputs": [],
   "source": [
    "#第10章/训练\n",
    "from transformers import AdamW\n",
    "from transformers.optimization import get_scheduler\n",
    "\n",
    "\n",
    "def train(epochs):\n",
    "    lr = 2e-5 if model.tuning else 5e-4\n",
    "\n",
    "    optimizer = AdamW(model.parameters(), lr=lr)\n",
    "    criterion = torch.nn.CrossEntropyLoss()\n",
    "    scheduler = get_scheduler(name='linear',\n",
    "                              num_warmup_steps=0,\n",
    "                              num_training_steps=len(loader) * epochs,\n",
    "                              optimizer=optimizer)\n",
    "\n",
    "    model.train()\n",
    "    for epoch in range(epochs):\n",
    "        for step, (inputs, labels) in enumerate(loader):\n",
    "            #模型计算\n",
    "            #[b, lens] -> [b, lens, 8]\n",
    "            outs = model(inputs)\n",
    "\n",
    "            #对outs和label变形,并且移除pad\n",
    "            #outs -> [b, lens, 8] -> [c, 8]\n",
    "            #labels -> [b, lens] -> [c]\n",
    "            outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                                  inputs['attention_mask'])\n",
    "\n",
    "            #梯度下降\n",
    "            loss = criterion(outs, labels)\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            if step % (len(loader) * epochs // 30) == 0:\n",
    "                counts = get_correct_and_total_count(labels, outs)\n",
    "                accuracy = counts[0] / counts[1]\n",
    "                accuracy_content = counts[2] / counts[3]\n",
    "                lr = optimizer.state_dict()['param_groups'][0]['lr']\n",
    "\n",
    "                print(epoch, step, loss.item(), lr, accuracy, accuracy_content)\n",
    "\n",
    "        torch.save(model, 'model/中文命名实体识别.model')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "1165e878",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "354.9704\n"
     ]
    }
   ],
   "source": [
    "#第10章/两段式训练第1步，训练下游任务模型\n",
    "model.fine_tuning(False)\n",
    "print(sum(p.numel() for p in model.parameters()) / 10000)\n",
    "#train(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "6c234616",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "4202.6504\n"
     ]
    }
   ],
   "source": [
    "#第10章/两段式训练第2步，同时训练下游任务模型和预训练模型\n",
    "model.fine_tuning(True)\n",
    "print(sum(p.numel() for p in model.parameters()) / 10000)\n",
    "#train(5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "0942c3a5",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.9923819197562215 0.9607843137254902\n"
     ]
    }
   ],
   "source": [
    "#第10章/测试\n",
    "def test():\n",
    "    #加载训练完的模型\n",
    "    model_load = torch.load('model/中文命名实体识别.model')\n",
    "    model_load.tuning = True\n",
    "    model_load.eval()\n",
    "    model_load.to(device)\n",
    "\n",
    "    #测试数据集加载器\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),\n",
    "                                              batch_size=128,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    correct_content = 0\n",
    "    total_content = 0\n",
    "\n",
    "    #遍历测试数据集\n",
    "    for step, (inputs, labels) in enumerate(loader_test):\n",
    "\n",
    "        #测试5个批次即可，不全部全部遍历\n",
    "        if step == 5:\n",
    "            break\n",
    "        print(step)\n",
    "\n",
    "        #计算\n",
    "        with torch.no_grad():\n",
    "            #[b, lens] -> [b, lens, 8] -> [b, lens]\n",
    "            outs = model_load(inputs)\n",
    "\n",
    "        #对outs和label变形,并且移除pad\n",
    "        #outs -> [b, lens, 8] -> [c, 8]\n",
    "        #labels -> [b, lens] -> [c]\n",
    "        outs, labels = reshape_and_remove_pad(outs, labels,\n",
    "                                              inputs['attention_mask'])\n",
    "\n",
    "        #统计正确数量\n",
    "        counts = get_correct_and_total_count(labels, outs)\n",
    "        correct += counts[0]\n",
    "        total += counts[1]\n",
    "        correct_content += counts[2]\n",
    "        total_content += counts[3]\n",
    "\n",
    "    print(correct / total, correct_content / total_content)\n",
    "\n",
    "\n",
    "test()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "fb9fa55c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[CLS]要建立健全信息网络，增强工作的预见性、前瞻性。[SEP]\n",
      "[CLS]7·······················[SEP]7\n",
      "[CLS]7·······················[SEP]7\n",
      "==========================\n",
      "[CLS]该书是迄今为止国内较为全面、系统地研究周恩来经济思想的一部专著，有助于深化对周恩来思想的研究。[SEP]\n",
      "[CLS]7···················周1恩2来2················周1恩2来2······[SEP]7\n",
      "[CLS]7···················周1恩2来2················周1恩2来2······[SEP]7\n",
      "==========================\n",
      "[CLS]会议之后，国际原油市场仍明显供大于求，价格十分疲软。[SEP]\n",
      "[CLS]7··························[SEP]7\n",
      "[CLS]7··························[SEP]7\n",
      "==========================\n",
      "[CLS]两个月后少女平静地离去，她的身边簇拥着俊平的朋友们，枕边还放着俊平为她捎去的书。[SEP]\n",
      "[CLS]7···················俊1平2··········俊1平2·······[SEP]7\n",
      "[CLS]7···················俊1平2··········俊1平2·······[SEP]7\n",
      "==========================\n",
      "[CLS]第二，在改制后的国有大中型企业设置党组。[SEP]\n",
      "[CLS]7····················[SEP]7\n",
      "[CLS]7····················[SEP]7\n",
      "==========================\n",
      "[CLS]中央政府严格遵守《基本法》，对于香港首届立法会选举事务未予任何干预。[SEP]\n",
      "[CLS]7················香3港4首4届4立4法4会4···········[SEP]7\n",
      "[CLS]7················香3港4首4届4立4法4会4···········[SEP]7\n",
      "==========================\n",
      "[CLS]对于发展中国家而言，汽车工业的集中化趋势显然是严峻的挑战。[SEP]\n",
      "[CLS]7·····························[SEP]7\n",
      "[CLS]7·····························[SEP]7\n",
      "==========================\n",
      "[CLS]时隔不久，中央电视台也对武汉电视台敞开了大门，把他们纳入了[UNK]科普宣传国家队[UNK]的行列。[SEP]\n",
      "[CLS]7·····中3央4电4视4台4··武3汉4电4视4台4·············科3普4宣4传4国4家4队4·····[SEP]7\n",
      "[CLS]7·····中3央4电4视4台4··武3汉4电4视4台4··············普4宣4传4国4家4队4·····[SEP]7\n",
      "==========================\n",
      "[CLS]俄罗斯总统叶利钦8日在这里明确表示反对外国对目前的科索沃冲突进行军事干涉，说外部力量的介入可能会导致这一地区冲突进一步扩大。[SEP]\n",
      "[CLS]7俄5罗6斯6··叶1利2钦2·················科5索6沃6··································[SEP]7\n",
      "[CLS]7俄5罗6斯6··叶1利2钦2·················科5索6沃6··································[SEP]7\n",
      "==========================\n",
      "[CLS]我每次来北京开会，彭真同志总要安排时间，同我交谈，询问西藏工作情况，了解有什么问题和实际需要。[SEP]\n",
      "[CLS]7····北5京6···彭1真2················西5藏6··················[SEP]7\n",
      "[CLS]7····北5京6···彭1真2················西5藏6··················[SEP]7\n",
      "==========================\n",
      "[CLS]社论说，中美两国首脑宣布互不将各自控制下的战略核武器瞄准对方，并就防止核扩散和裁减大规模杀伤性武器的问题达成了协议，这将在寻求和平的国际政治史上留下重要的一笔。[SEP]\n",
      "[CLS]7····中5美5··········································································[SEP]7\n",
      "[CLS]7····中5美5··········································································[SEP]7\n",
      "==========================\n",
      "[CLS]面向下岗职工，社区服务走出新天地[SEP]\n",
      "[CLS]7················[SEP]7\n",
      "[CLS]7················[SEP]7\n",
      "==========================\n",
      "[CLS]致公党中央主席罗豪才主持会议并讲话，强调致公党要切实搞好领导班子建设，领导班子做到[UNK]眼界开阔[UNK]、[UNK]胸襟宽广[UNK]。[SEP]\n",
      "[CLS]7致3公4党4中4央4··罗1豪2才2··········致3公4党4································[SEP]7\n",
      "[CLS]7致3公4党4中4央4··罗1豪2才2··········致3公3党4································[SEP]7\n",
      "==========================\n",
      "[CLS]全国人大常委会委员、法律委员会副主任委员李伯勇、全国人大常委会副秘书长吕聪敏和英国驻华使节到机场送行。[SEP]\n",
      "[CLS]7全3国4人4大4常4委4会4···法3律4委4员4会4·····李1伯2勇2·全3国4人4大4常4委4会4····吕1聪2敏2·英5国6·华5········[SEP]7\n",
      "[CLS]7全3国4人4大4常4委4会4···法3律4委4员4会4·····李1伯2勇2·全3国4人4大4常4委4会4····吕1聪2敏2·英5国6驻4华4使4节4······[SEP]7\n",
      "==========================\n",
      "[CLS]1958年我毕业后被分配到民歌之乡的山西工作。[SEP]\n",
      "[CLS]7··················山5西6···[SEP]7\n",
      "[CLS]7··················山5西6···[SEP]7\n",
      "==========================\n",
      "[CLS]用爱，去编织温馨[UNK][UNK][UNK]记四川有线电视台妇女儿童频道[SEP]\n",
      "[CLS]7············四3川4有4线4电4视4台4······[SEP]7\n",
      "[CLS]7············四3川4有4线4电4视4台4······[SEP]7\n",
      "==========================\n",
      "[CLS]栾城县委、县政府的做法推动了全县干部工作作风的转变，受到群众的欢迎。[SEP]\n",
      "[CLS]7栾3城4县4委4······························[SEP]7\n",
      "[CLS]7栾3城4县4委4······························[SEP]7\n",
      "==========================\n",
      "[CLS]他说，在亚太地区安全、防止大规模杀伤性武器扩散、打击国际犯罪、保护环境、双边贸易与能源合作等各个领域，美国都需要与中国合作。[SEP]\n",
      "[CLS]7····亚5太6·············································美5国6····中5国6···[SEP]7\n",
      "[CLS]7····亚5太6·············································美5国6····中5国6···[SEP]7\n",
      "==========================\n",
      "[CLS]双方还将扩大在科技、文化、教育、艺术等领域的相互交流与合作。[SEP]\n",
      "[CLS]7······························[SEP]7\n",
      "[CLS]7······························[SEP]7\n",
      "==========================\n",
      "[CLS]阿尔及利亚的天然气产量占非洲天然气产量的68％。[SEP]\n",
      "[CLS]7阿5尔6及6利6亚6·······非5洲6··········[SEP]7\n",
      "[CLS]7阿5尔6及6利6亚6·······非5洲6··········[SEP]7\n",
      "==========================\n",
      "[CLS]的确，这是人们都很关注的问题。[SEP]\n",
      "[CLS]7···············[SEP]7\n",
      "[CLS]7···············[SEP]7\n",
      "==========================\n",
      "[CLS]本报新德里5月28日电记者李文云报道：今晚7时20分，印度外交部就巴基斯坦进行核试验发表正式声明。[SEP]\n",
      "[CLS]7··新5德6里6········李1文2云2···········印3度4外4交4部4·巴5基6斯6坦6············[SEP]7\n",
      "[CLS]7··新5德6里6········李1文2云2···········印3度4外4交4部4·巴5基6斯6坦6············[SEP]7\n",
      "==========================\n",
      "[CLS]但是，外国金融机构大规模进军日本金融业可以带来新的经营方式，增强日本金融界的活力。[SEP]\n",
      "[CLS]7··············日5本6················日5本6·······[SEP]7\n",
      "[CLS]7··············日5本6················日5本6·······[SEP]7\n",
      "==========================\n",
      "[CLS]1998年度[UNK]春晖计划[UNK]还列出了资助的主要领域有：能源、电子信息技术、计算机技术、生物技术、农业可持续发展等11个方面；主要专业方向有：电子学与信息系统、微电子技术、生物医学工程、煤炭利用工程等16个方面。[SEP]\n",
      "[CLS]7·······································································································[SEP]7\n",
      "[CLS]7·······································································································[SEP]7\n",
      "==========================\n",
      "[CLS]80年代中期，拥有丰富的国际金融商战经验的中国银行率先在国内开办了代客债务风险管理业务，其目的就是防范风险，控制成本。[SEP]\n",
      "[CLS]7·····················中3国4银4行4··································[SEP]7\n",
      "[CLS]7·····················中3国4银4行4··································[SEP]7\n",
      "==========================\n",
      "[CLS]瓦杰帕伊称，现在应讨论签订一项全球核裁军公约，所有那些拥有核武器的国家都应承担裁减核武器的义务。[SEP]\n",
      "[CLS]7瓦1杰2帕2伊2············································[SEP]7\n",
      "[CLS]7瓦1杰2帕2伊2············································[SEP]7\n",
      "==========================\n",
      "[CLS]5月23日是墨西哥全国防疫接种卫生周第一天。[SEP]\n",
      "[CLS]7······墨5西6哥6·············[SEP]7\n",
      "[CLS]7······墨5西6哥6·············[SEP]7\n",
      "==========================\n",
      "[CLS]在这次由印度核试验引发的紧张局势中，双方在克什米尔问题上，既有[UNK]舌战[UNK]，也有[UNK]热战[UNK]，双方在实际控制线地区曾有过交火。[SEP]\n",
      "[CLS]7····印5度6···············克5什6米6尔6··································[SEP]7\n",
      "[CLS]7····印5度6···············克5什6米6尔6··································[SEP]7\n",
      "==========================\n",
      "[CLS]是不是越便宜越有益我们这些工薪阶层？[SEP]\n",
      "[CLS]7··················[SEP]7\n",
      "[CLS]7··················[SEP]7\n",
      "==========================\n",
      "[CLS]必须广泛吸收现代经济理论的成果并予以创新，才能把握转轨经济中的问题，统率所掌握的资料。[SEP]\n",
      "[CLS]7···········································[SEP]7\n",
      "[CLS]7···········································[SEP]7\n",
      "==========================\n",
      "[CLS]在德班，姆贝基副总统在发表[UNK]自由日[UNK]演讲时指出，300多年种族隔离制度遗留下来的问题，不可能在短短4年的时间里就能解决，南非仍然存在着种族矛盾、贫富差别和诸多不安定因素。[SEP]\n",
      "[CLS]7·德5班6·姆1贝2基2·····················································南5非6·······················[SEP]7\n",
      "[CLS]7·德5班6·姆1贝2基2·····················································南5非6·······················[SEP]7\n",
      "==========================\n",
      "[CLS]中、斯联合登山队将认真总结第一次攀登失利的经验，根据队员的身体状况，伺机进行第二次突击顶峰。[SEP]\n",
      "[CLS]7中3、4斯4联4合4登4山4队4······································[SEP]7\n",
      "[CLS]7中3·斯3联4合4登4山4队4······································[SEP]7\n",
      "==========================\n"
     ]
    }
   ],
   "source": [
    "#第10章/预测\n",
    "def predict():\n",
    "    #加载模型\n",
    "    model_load = torch.load('model/中文命名实体识别.model')\n",
    "    model_load.tuning = True\n",
    "    model_load.eval()\n",
    "    model_load.to(device)\n",
    "\n",
    "    #测试数据集加载器\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    #取一个批次的数据\n",
    "    for i, (inputs, labels) in enumerate(loader_test):\n",
    "        break\n",
    "\n",
    "    #计算\n",
    "    with torch.no_grad():\n",
    "        #[b, lens] -> [b, lens, 8] -> [b, lens]\n",
    "        outs = model_load(inputs).argmax(dim=2)\n",
    "\n",
    "    for i in range(32):\n",
    "        #移除pad\n",
    "        select = inputs['attention_mask'][i] == 1\n",
    "        input_id = inputs['input_ids'][i, select]\n",
    "        out = outs[i, select]\n",
    "        label = labels[i, select]\n",
    "\n",
    "        #输出原句子\n",
    "        print(tokenizer.decode(input_id).replace(' ', ''))\n",
    "\n",
    "        #输出tag\n",
    "        for tag in [label, out]:\n",
    "            s = ''\n",
    "            for j in range(len(tag)):\n",
    "                if tag[j] == 0:\n",
    "                    s += '·'\n",
    "                    continue\n",
    "                s += tokenizer.decode(input_id[j])\n",
    "                s += str(tag[j].item())\n",
    "\n",
    "            print(s)\n",
    "        print('==========================')\n",
    "\n",
    "\n",
    "predict()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
